import os

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset, Dataset
from sklearn import datasets as sk_ds
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn import preprocessing
import numpy as np
import pandas as pd
from torchvision.transforms.v2.functional import normalize

# Random state shared by the synthetic dataset generation and its train/test split.
seed = 10

# Statistics handed to transforms.Normalize by the image loaders in this file.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

# NOTE(review): presumably Fashion-MNIST statistics; not referenced in the visible
# code (fashion_mnist() normalizes with MNIST_MEAN / MNIST_STD) - confirm intent.
F_MNIST_MEAN = 0.3814
F_MNIST_STD = 0.3994

# NOTE(review): presumably per-channel CIFAR statistics; not referenced in the
# visible code.
CIFAR_MEAN = (0.4915, 0.4823, .4468)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

def data_generator(args, device):
    """Build the train/test loaders for the dataset named by ``args.dataset``.

    Args:
        args: Object exposing ``dataset`` (one of ``'mnist'``,
            ``'fashion_mnist'``, ``'synthetic'``, ``'adults'``,
            ``'adults_sgd'``) and ``batch_size``.
        device: Forwarded to the image loaders; the adult tensors are moved
            to it. The synthetic tensors are not moved to any device.

    Returns:
        Tuple ``(train_loader, test_loader, input_dim, output_dim)``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    name = args.dataset
    if name == 'mnist':
        return original_load_mnist(args.batch_size, device)
    if name == 'fashion_mnist':
        return fashion_mnist(args.batch_size, device)
    if name == 'synthetic':
        return _synthetic_loaders(args.batch_size)
    if name in ('adults', 'adults_sgd'):
        # Both names share exactly the same data pipeline.
        return _adult_loaders(args.batch_size, device)
    raise ValueError(
        f"Unknown dataset {name!r}; expected one of 'mnist', 'fashion_mnist', "
        "'synthetic', 'adults', 'adults_sgd'"
    )


def _synthetic_loaders(batch_size):
    """Create loaders over a seeded scikit-learn classification problem."""
    features = 20  # dimensions, all of them informative
    classes = 5
    test_fraction = 0.2
    total_size = 6000

    x_total, y_total = sk_ds.make_classification(
        n_features=features, n_redundant=0, n_informative=features,
        n_classes=classes, n_samples=total_size, random_state=seed, class_sep=3,
        shuffle=True,
    )
    # No feature scaling or augmentation is applied to the synthetic data.
    x_train, x_test, y_train, y_test = train_test_split(
        x_total, y_total, test_size=test_fraction, random_state=seed
    )

    train_dataset = TensorDataset(
        torch.tensor(x_train, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.long),
    )
    test_dataset = TensorDataset(
        torch.tensor(x_test, dtype=torch.float32),
        torch.tensor(y_test, dtype=torch.long),
    )
    train_l = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
    test_l = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    return train_l, test_l, features, classes


def _adult_loaders(batch_size, device):
    """Create loaders over the adult data, with a noisy copy added to training."""
    num_train = 40000
    labels, data = load_adult_data()
    # The split is unseeded, so it differs between runs.
    feature_train, feature_test, label_train, label_test = train_test_split(
        data, labels, train_size=num_train
    )

    x_train = torch.from_numpy(feature_train).to(dtype=torch.float).to(device=device)
    y_train = torch.from_numpy(label_train).to(dtype=torch.int64).to(device=device)

    # Double the training set with a slightly perturbed copy of every row.
    x_aug = x_train + torch.randn_like(x_train) * 0.01
    x_train = torch.cat((x_train, x_aug), dim=0)
    y_train = torch.cat((y_train, y_train), dim=0)

    x_test = torch.from_numpy(feature_test).to(dtype=torch.float).to(device=device)
    y_test = torch.from_numpy(label_test).to(dtype=torch.int64).to(device=device)

    # Scale every sample (row) to unit L2 norm.
    x_train = torch.nn.functional.normalize(x_train, dim=1, p=2)
    x_test = torch.nn.functional.normalize(x_test, dim=1, p=2)

    train_l = DataLoader(dataset=TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=False)
    test_l = DataLoader(dataset=TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)
    return train_l, test_l, data.shape[1], 2




def original_load_mnist(batch_size, device, aug=False):
    """Build MNIST train/test loaders.

    Args:
        batch_size: Batch size of the training loader.
        device: Accepted for interface compatibility; no tensor is moved to
            it here, the loaders are returned without being iterated.
        aug: When true, the training set is the plain MNIST training set
            concatenated with a randomly rotated (up to 10 degrees) copy.

    Returns:
        Tuple ``(train_loader, test_loader, input_dim, output_dim)`` with
        ``input_dim == 28 * 28`` and ``output_dim == 10``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
    ])

    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

    if aug:
        transform_aug = transforms.Compose([
            transforms.RandomRotation(10),  # random rotation
            transforms.ToTensor(),
            transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
        ])
        train_dataset_aug = datasets.MNIST(root='./data', train=True, transform=transform_aug, download=True)
        train_combine = ConcatDataset([train_dataset, train_dataset_aug])
    else:
        train_combine = train_dataset

    train_l = DataLoader(dataset=train_combine, batch_size=batch_size, shuffle=False)
    # NOTE(review): no batch_size is given, so the test loader uses the
    # DataLoader default of one sample per batch - confirm this is intended.
    test_l = DataLoader(dataset=test_dataset, shuffle=False)

    input_dim = 28 * 28
    output_dim = 10
    return train_l, test_l, input_dim, output_dim



def load_adult_data():
    """Read ``./data/adults.csv`` and return it as numeric arrays.

    Rows containing a ``'?'`` entry are dropped, the ``salary`` column is
    integer-coded in order of first appearance, the categorical columns are
    one-hot encoded, and the listed numeric columns are standardized.
    Progress lines are printed along the way.

    Returns:
        Tuple ``(label, data)`` of float NumPy arrays: the coded ``salary``
        column and all remaining columns.
    """
    column_names = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation',
                    'relationship', 'race', 'sex', 'capital-gain',
                    'capital-loss', 'hours-per-week', 'native-country', 'salary']
    one_hot_cols = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex',
                    'native-country']
    scaled_cols = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']
    target_cols = ['salary']

    frame = pd.read_csv("./data/adults.csv", index_col=False, skipinitialspace=True, header=None,
                        names=column_names)
    # '?' marks a missing entry; discard every row that has one.
    frame = frame.replace('?', np.nan)
    frame.dropna(inplace=True)

    # Integer-code the target, numbering values in order of first appearance.
    for name in target_cols:
        codes = {value: code for code, value in enumerate(frame[name].unique().tolist())}
        frame[name] = frame[name].map(codes).astype(int)
        print(name + " done!")

    # Swap the categorical columns for their one-hot indicator columns.
    indicators = pd.get_dummies(frame[one_hot_cols])
    frame = pd.concat([frame.drop(one_hot_cols, axis=1), indicators], axis=1)

    for name in target_cols:
        distinct = frame[name].unique().tolist()
        print(name + " has " + str(len(distinct)) + " values" + " : " + str(distinct))

    # Standardize the selected numeric columns (fitted on the whole table).
    frame[scaled_cols] = preprocessing.StandardScaler().fit_transform(frame[scaled_cols])

    label = frame["salary"].to_numpy().astype(float)
    data = frame.drop("salary", axis=1).to_numpy().astype(float)

    return label, data

def fashion_mnist(batch_size, device, aug=True):
    """Build Fashion-MNIST train/test loaders.

    Args:
        batch_size: Batch size of the training loader.
        device: Accepted for interface compatibility; no tensor is moved to
            it here, the loaders are returned without being iterated.
        aug: When true (the default), the training set is the plain training
            set concatenated with a randomly rotated (up to 10 degrees) copy.

    Returns:
        Tuple ``(train_loader, test_loader, input_dim, output_dim)`` with
        ``input_dim == 28 * 28`` and ``output_dim == 10``.
    """
    # NOTE(review): normalization uses the MNIST statistics although
    # F_MNIST_MEAN / F_MNIST_STD are defined in this module - confirm intent.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
    ])

    train_dataset = datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=transform)

    if aug:
        transform_aug = transforms.Compose([
            transforms.RandomRotation(10),  # random rotation
            transforms.ToTensor(),
            transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
        ])
        train_dataset_aug = datasets.FashionMNIST(root='./data', train=True, transform=transform_aug, download=True)
        train_combine = ConcatDataset([train_dataset, train_dataset_aug])
    else:
        train_combine = train_dataset

    train_l = DataLoader(dataset=train_combine, batch_size=batch_size, shuffle=False)
    # NOTE(review): no batch_size is given, so the test loader uses the
    # DataLoader default of one sample per batch - confirm this is intended.
    test_l = DataLoader(dataset=test_dataset, shuffle=False)

    input_dim = 28 * 28
    output_dim = 10
    return train_l, test_l, input_dim, output_dim

def take_first(dataset: TensorDataset, num_to_keep: int):
    """Return a new TensorDataset holding the leading ``num_to_keep`` entries.

    Only the first two tensors of ``dataset`` are carried over.
    """
    inputs = dataset.tensors[0]
    targets = dataset.tensors[1]
    return TensorDataset(inputs[:num_to_keep], targets[:num_to_keep])


def iterate_dataset(dataset: Dataset, batch_size: int, device):
    """Wrap ``dataset`` in an unshuffled DataLoader and return that loader.

    ``device`` is accepted but not used: the batches are not moved to it.
    """
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

